{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "72311700-326d-4418-afcf-8246d368c51c",
   "metadata": {},
   "source": [
    "# Strain clustering with Bacformer tutorial\n",
    "\n",
    "This tutorial shows how to use a pretrained Bacformer model to cluster strains. Bacformer outputs contextual protein embeddings, and we use the average of all\n",
    "contextual protein embeddings as a genome embedding.\n",
    "\n",
    "We use a small random sample of 30 genomes spanning 4 distinct species and 3 families to demonstrate how to embed genomes with Bacformer and cluster the resulting embeddings.\n",
    "The genomes were extracted from [MGnify](https://www.ebi.ac.uk/metagenomics).\n",
    "\n",
    "Before you start, make sure `bacformer` is installed (see README.md for details) and run the notebook on a machine with a GPU.\n",
    "\n",
    "## Step 1: Import required dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e456f1c3-432c-47c4-b31e-267586f12f39",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/maciejwiatrak/miniconda3/envs/bacformer-release/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "WARNING: faESM (fast ESM) not installed, this will lead to significant slowdown. Defaulting to use HuggingFace implementation. Please consider installing faESM: https://github.com/pengzhangzhi/faplm\n"
     ]
    }
   ],
   "source": [
    "import anndata as ad\n",
    "import numpy as np\n",
    "import scanpy as sc\n",
    "from bacformer.pp import embed_dataset_col\n",
    "from datasets import load_dataset\n",
    "from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_score\n",
    "from sklearn.preprocessing import LabelEncoder"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "461dfa81-f150-42d2-918d-adf0b7a5f41f",
   "metadata": {},
   "source": [
    "## Step 2: Load the dataset\n",
    "\n",
    "Load the sample dataset from HuggingFace."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c926b79-0a00-4844-8437-7d22e8fc2396",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the dataset\n",
    "dataset = load_dataset(\"macwiatrak/strain-clustering-protein-sequences-sample\", split=\"train\")"
   ]
  },
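  {
   "cell_type": "markdown",
   "id": "sketch-inspect-dataset-md",
   "metadata": {},
   "source": [
    "Before embedding, it can help to check what was loaded. The cell below is a minimal sketch that relies only on standard `datasets.Dataset` attributes; it makes no assumption about the exact column names, which depend on the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-inspect-dataset-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# print the dataset summary (features and number of rows)\n",
    "print(dataset)\n",
    "\n",
    "# each row corresponds to one genome\n",
    "print(f\"Number of genomes: {len(dataset)}\")\n",
    "print(f\"Columns: {dataset.column_names}\")"
   ]
  },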
  {
   "cell_type": "markdown",
   "id": "11557de4-2863-4e4f-86b7-4dbda999043f",
   "metadata": {},
   "source": [
    "## Step 3: Compute Bacformer embeddings\n",
    "\n",
    "Convert the protein sequences to genome embeddings. This is done in 2 steps:\n",
    "1. Embed the protein sequences with the base protein language model (pLM), [ESM-2 t12 35M](https://huggingface.co/facebook/esm2_t12_35M_UR50D).\n",
    "2. Pass the protein embeddings to the Bacformer model, which computes contextual protein embeddings and averages them to obtain a genome embedding.\n",
    "\n",
    "This step takes ~2 min on a single NVIDIA A100 GPU with `flash-attention` installed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42ead7da-ec6d-42a3-90f5-f11079a75e25",
   "metadata": {},
   "outputs": [],
   "source": [
    "# embed the protein sequences with Bacformer\n",
    "dataset = embed_dataset_col(\n",
    "    dataset=dataset,\n",
    "    model_path=\"macwiatrak/bacformer-masked-MAG\",\n",
    "    max_n_proteins=9000,\n",
    "    genome_pooling_method=\"mean\",\n",
    ")"
   ]
  },
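  {
   "cell_type": "markdown",
   "id": "sketch-mean-pooling-md",
   "metadata": {},
   "source": [
    "To make the `genome_pooling_method=\"mean\"` idea concrete, the cell below is a minimal NumPy sketch of mean pooling on toy data. It is not Bacformer's internal implementation: the array sizes, the random values and the padding mask are all assumptions made for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-mean-pooling-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "# toy stand-in for contextual protein embeddings: 5 protein slots x 8 dimensions\n",
    "# (hypothetical sizes, chosen only for illustration)\n",
    "protein_embeddings = rng.normal(size=(5, 8))\n",
    "\n",
    "# assumed padding mask: 1 marks a real protein, 0 marks a padded slot\n",
    "mask = np.array([1, 1, 1, 1, 0])\n",
    "\n",
    "# average over real proteins only, giving one vector per genome\n",
    "genome_embedding = (protein_embeddings * mask[:, None]).sum(axis=0) / mask.sum()\n",
    "print(genome_embedding.shape)  # (8,)"
   ]
  },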
  {
   "cell_type": "markdown",
   "id": "ef86950d-9354-47a6-9c22-ed94c74cafbb",
   "metadata": {},
   "source": [
    "## Step 4: Cluster the genome embeddings\n",
    "\n",
    "We use [scanpy](https://scanpy.readthedocs.io/en/stable/) for clustering, so we convert the data to an `AnnData` object, build the nearest-neighbor graph from the genome embeddings, and compute the `UMAP`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71cb8758-2a20-40f9-9697-b9f5fcf6254d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# convert dataset to pandas DataFrame\n",
    "df = dataset.to_pandas()\n",
    "\n",
    "# create anndata object needed for clustering\n",
    "embeddings = np.stack(df[\"embeddings\"].tolist())  # get embedding matrix\n",
    "adata = ad.AnnData(\n",
    "    X=embeddings,\n",
    "    obs=df.drop(columns=[\"embeddings\"]).copy(),\n",
    ")\n",
    "\n",
    "# compute neighbors with scanpy\n",
    "sc.pp.neighbors(adata, use_rep=\"X\")\n",
    "\n",
    "# compute UMAP\n",
    "sc.tl.umap(adata)"
   ]
  },
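  {
   "cell_type": "markdown",
   "id": "sketch-check-umap-md",
   "metadata": {},
   "source": [
    "A quick sanity check after this step, as a sketch: `sc.tl.umap` stores the coordinates in `adata.obsm[\"X_umap\"]`, with one row per genome. Note that `sc.pp.neighbors` defaults to `n_neighbors=15`, which is half of this 30-genome sample; trying a smaller value is a reasonable experiment, although that is a suggestion and not part of the original workflow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-check-umap-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# summary of the AnnData object: genomes are observations, embedding dimensions are variables\n",
    "print(adata)\n",
    "\n",
    "# UMAP coordinates, one row per genome (2 columns by default)\n",
    "print(adata.obsm[\"X_umap\"].shape)"
   ]
  },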
  {
   "cell_type": "markdown",
   "id": "2fc89db4-16b5-401f-82eb-34a56f86b33f",
   "metadata": {},
   "source": [
    "## Step 5: Plot UMAPs\n",
    "\n",
    "Plot UMAPs by species and family labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27bee063-b121-4a7e-85ac-d2b9f4067020",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot UMAP by species\n",
    "sc.pl.umap(adata, color=\"species\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98d32a90-a4b5-49d3-b7d0-8a1732b6457b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot UMAP by family\n",
    "sc.pl.umap(adata, color=\"family\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ca5dedf-67d8-486b-8759-fac7760edfa3",
   "metadata": {},
   "source": [
    "## [Optional] Step 6: Compute clustering metrics\n",
    "\n",
    "Run `Leiden` clustering and compute clustering metrics to evaluate how well the model clusters strains by label (here, species)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e153465-a963-4ff2-84d9-7f059b4dbb87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# run Leiden clustering on the neighbor graph (optional)\n",
    "sc.tl.leiden(adata, resolution=0.1, key_added=\"leiden_clusters\")\n",
    "\n",
    "# Convert Leiden cluster labels to integer labels\n",
    "leiden_clusters = adata.obs[\"leiden_clusters\"].astype(int)\n",
    "\n",
    "# Encode ground-truth labels\n",
    "label_encoder = LabelEncoder()\n",
    "numeric_labels = label_encoder.fit_transform(adata.obs[\"species\"])\n",
    "\n",
    "# Compute ARI, NMI, and Silhouette\n",
    "ari = adjusted_rand_score(numeric_labels, leiden_clusters)\n",
    "nmi = normalized_mutual_info_score(numeric_labels, leiden_clusters)\n",
    "# Silhouette requires sample-level features + predicted labels\n",
    "sil = silhouette_score(adata.X, leiden_clusters)\n",
    "print(f\"ARI: {ari:.3f}, NMI: {nmi:.3f}, Silhouette Score: {sil:.3f}\")"
   ]
  },
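  {
   "cell_type": "markdown",
   "id": "sketch-contingency-md",
   "metadata": {},
   "source": [
    "ARI and NMI compare groupings rather than label names, so the arbitrary numbering of Leiden clusters does not affect them. The sketch below first shows this on a toy example, then prints a species-by-cluster contingency table to show which clusters each species falls into. The `pd.crosstab` view is an addition for inspection and not part of the metric computation above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-contingency-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.metrics import adjusted_rand_score\n",
    "\n",
    "# same grouping with swapped label names still gives a perfect score\n",
    "print(adjusted_rand_score([0, 0, 1, 1], [1, 1, 0, 0]))  # 1.0\n",
    "\n",
    "# rows: ground-truth species, columns: Leiden clusters\n",
    "contingency = pd.crosstab(adata.obs[\"species\"], adata.obs[\"leiden_clusters\"])\n",
    "print(contingency)"
   ]
  },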
  {
   "cell_type": "markdown",
   "id": "e3769487-8664-4c85-b55d-735932a4a1e5",
   "metadata": {},
   "source": [
    "----------------------\n",
    "\n",
    "#### Voilà, you made it 👏!\n",
    "\n",
    "If you run into problems or have questions, open an issue on GitHub: https://github.com/macwiatrak/Bacformer/issues."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "666d8d81-965d-4366-9948-7956cf393b0b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
